#include "statsxx/machine_learning/NeuralNet.hpp"

// STL
#include <cstddef>                              // std::size_t
#include <fstream>                              // std::ofstream
#include <numeric>                              // std::iota()
#include <tuple>                                // std::tie()
#include <utility>                              // std::pair
#include <vector>                               // std::vector

// jScience
#include "jutility.hpp"                         // get_n_rand_unique_elements()
#include "jScience/math/utility.hpp"            // ngroups()
#include "jScience/stats/data.hpp"              // shuffle_data()

// stats++
#include "statsxx/dataset.hpp"       // DataSet
#include "statsxx/machine_learning/NeuralNet/utility.hpp" // fit_validation_error()


// Appears to stochastically binarize the (probability-valued) outputs of a
// dataset — see the commented-out definition at the bottom of this file.
// NOTE(review): the definition is currently disabled, so this declaration is
// unused; it is kept so the disabled preprocessing step in backpropagation()
// can be re-enabled without further changes.
static DataSet prob_to_binary(
                              DataSet ds
                              );


//
// DESC:
//
// INPUT:
//          DataSet ds_tr  :: Training dataset
//          DataSet ds_val :: Validation dataset
//
inline void NEURAL_NET::backpropagation(
                                        const int      max_epochs,
                                        // -----
                                        const bool     early_stopping,
                                        // -----
                                        const double   lr,
                                        // -----
                                        const double   momentum,
                                        // -----
                                        const double   weight_penalty,
                                        // -----
                                        const DataSet &ds_tr,
                                        const DataSet &ds_val,
                                        // -----
                                        const int      npts_per_batch
                                        )
{
//    static unsigned int max_epochs = 1000;

    // note: the following (three) default settings were taken 10/2/15 from those for the DBN
//    static double lr               = 0.1;       // learning rate
//   static double momentum         = 0.05;       // momentum term [0,1]
//    static double weight_penalty   = 0.001;     // (weight) penalization term

//    static int    nbatches         = 10;

//    static bool   stochastic       = true;    // stochastic learning

    int nbatches;
    if(npts_per_batch < 0)
    {
        nbatches = 1;
    }
    else
    {
        nbatches = ds_tr.pt.size()/npts_per_batch;
    }

    //*********************************************************

    std::vector<double> w;
    for( auto &link : m_links )
    {
        w.push_back(link.weight);
    }

    std::vector<double> w_inc(w.size(), 0.0);

    // setup indices for batches
    std::vector<int> shuffle_idx(ds_tr.pt.size());
    std::iota(shuffle_idx.begin(), shuffle_idx.end(), 0);

    std::vector<int> batch_begin;
    std::vector<int> batch_end;
    std::vector<int> batch_size;
    std::tie(batch_begin, batch_end, batch_size) = ngroups(ds_tr.pt.size(), nbatches);

    // CONVG/STORAGE
    std::vector<double> Etr, Etr_stddev, Ev, Ev_stddev;
    std::vector<std::vector<double>> w_store;

    std::ofstream ofs("./training_errs.dat", std::ios::app);

    for(int epoch = 0; epoch < max_epochs; ++epoch)
    {

/*
        DataSet _ds_tr = prob_to_binary(
                                        ds_tr
                                        );

        DataSet _ds_val = prob_to_binary(
                                         ds_val
                                         );
*/

        DataSet _ds_tr = ds_tr;

        DataSet _ds_val = ds_val;

//        std::cout << "here" << std::endl;
//        std::cout << "_ds_tr.pt.size() : " << _ds_tr.pt.size() << std::endl;
//        std::cout << "_ds_val.pt.size(): " << _ds_val.pt.size() << std::endl;

        // note: every epoch we shuffle the data (the batches) in order to remove any possible serial correlation or accidental groupings of cases
        shuffle_data(shuffle_idx);

        for(int batch = 0; batch < nbatches; ++batch)
        {
            std::vector<double> dEdw(w.size(), 0.0);

            for(int i = batch_begin[batch]; i <= batch_end[batch]; ++i)
            {
                std::vector<double> dEdw_i;

                DataSet ds_tmp;
                ds_tmp.pt.push_back( _ds_tr.pt[shuffle_idx[i]] );

                calulate_Ederiv( ds_tmp, dEdw_i );

                for(auto j = 0; j < dEdw.size(); ++j)
                {
                    dEdw[j] += dEdw_i[j];
                }
            }

            // note: normalization ensures that we can consistently define all penalties, learning rates, etc.
            for(auto j = 0; j < dEdw.size(); ++j)
            {
                dEdw[j] /= batch_size[batch];
            }

            // PENALIZE LARGE WEIGHTS
            for(auto j = 0; j < dEdw.size(); ++j)
            {
                dEdw[j] += weight_penalty*w[j];
            }

            // COMPUTE INCREMENTS, AND UPDATE WEIGHTS FOR THIS BATCH
            for(auto j = 0; j < dEdw.size(); ++j)
            {
                w_inc[j] = momentum*w_inc[j] - lr*dEdw[j];
            }

            for(auto j = 0; j < dEdw.size(); ++j)
            {
                w[j] += w_inc[j];
            }

            assign_weights(w);
        }

        // ***
        double Emean;
        double Estddev;
        std::vector<std::vector<double>> NN_out;

        std::tie( Emean, Estddev, NN_out) = parse_TS(
                                                     _ds_tr
                                                     );
        Etr.push_back(Emean);
        Etr_stddev.push_back(Estddev);

        std::tie( Emean, Estddev, NN_out) = parse_TS(
                                                     _ds_val
                                                     );
        Ev.push_back(Emean);
        Ev_stddev.push_back(Estddev);

        ofs << (epoch+1) << "   " << Etr.back() << "   " << Etr_stddev.back() << "   " << Ev.back() << "   " << Ev_stddev.back() << std::endl;

        w_store.push_back(w);
        // ****
    }

    ofs.close();

    if( early_stopping )
    {
        // FIT THE VALIDATION ERROR TO A POLYNOMIAL, FINDING THE MINIMUM

        // standard deviations of Ev are often very large, though it is probably rigorously correct to use them
        std::vector<double>::size_type indx;
        std::vector<double>            f;
        std::tie(
                 f,
                 indx
                 ) = fit_validation_error(
                                          Ev,
                                          Ev_stddev
                                          );

        std::ofstream ofs2("fit_valid_err.dat");

        ofs2 << "using point: " << indx << '\n';

        for( auto i = 0; i < f.size(); ++i )
        {
            // NOTE: (i+1) is used so that f aligns with the errors in training_err.dat
            ofs2 << (i+1) << "   " << f[i] << '\n';
        }

        ofs2.close();

        assign_weights(w_store[indx]);
    }
    else
    {
        assign_weights(w_store.back());
    }
}

/*
static DataSet prob_to_binary(
                              DataSet ds
                              )
{
    for( auto &pt : ds.pt )
    {
        for( auto i = 0; i < pt.out.size(); ++i )
        {
            double r = rand_num_uniform_Mersenne_twister(0., 1.);

            if( r < pt.out[i] )
            {
                pt.out[i] = 1.;
            }
            else
            {
                pt.out[i] = 0.;
            }
        }
    }

    return ds;
}
*/
